# Source code for nlp_architect.nn.tensorflow.python.keras.layers.crf

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
import tensorflow as tf

class CRF(tf.keras.layers.Layer):
    """
    Conditional Random Field layer (tf.keras)

    `CRF` can be used as the last layer in a network (as a classifier). The size of the
    last input dimension (features) must be equal to the number of classes the CRF can
    predict (a linear layer before it is recommended).

    Note: networks using `CRF` must use the loss and accuracy functions provided by this
    layer (`loss` and `viterbi_accuracy`), because sequence classification depends on
    the layer's internal transition weights.

    Args:
        num_classes (int): the number of labels to tag each temporal input with.

    Input shape:
        3D tensor with shape `(batch_size, sentence length, num_classes)`.

    Output shape:
        3D tensor with shape `(batch_size, sentence length, num_classes)`.
    """

    def __init__(self, num_classes, **kwargs):
        super(CRF, self).__init__(**kwargs)
        # [num_classes, num_classes] transition weight; created in `build`.
        self.transitions = None
        # num of output labels
        self.output_dim = int(num_classes)
        self.input_spec = tf.keras.layers.InputSpec(min_ndim=3)
        self.supports_masking = False
        # Per-example lengths of the last batch seen by `call`; read by `loss`.
        self.sequence_lengths = None

    def get_config(self):
        """Return the layer configuration.

        The keys ``output_dim``, ``supports_masking`` and ``transitions`` are kept for
        backward compatibility. ``transitions`` is ``None`` when the layer has not
        been built yet (previously this raised, since the weight did not exist).
        """
        transitions = None
        if self.transitions is not None:
            transitions = tf.keras.backend.eval(self.transitions)
        config = {
            "output_dim": self.output_dim,
            "supports_masking": self.supports_masking,
            "transitions": transitions,
        }
        base_config = super(CRF, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        """Re-create a `CRF` from the dict produced by `get_config`.

        `get_config` stores the class count under ``output_dim`` while `__init__`
        expects ``num_classes``, so the keys are mapped here. ``supports_masking`` is
        fixed by `__init__`. ``transitions`` cannot be passed to the constructor, so it
        is dropped; the trained weight must be restored separately
        (e.g. with `set_weights`).
        """
        config = dict(config)
        num_classes = config.pop("output_dim", None)
        if num_classes is None:
            num_classes = config.pop("num_classes")
        config.pop("supports_masking", None)
        config.pop("transitions", None)
        return cls(num_classes, **config)

    def build(self, input_shape):
        """Validate the input shape and create the transition-matrix weight.

        Raises:
            ValueError: if the last input dimension is undefined or differs from the
                number of classes.
        """
        assert len(input_shape) == 3
        f_shape = tf.TensorShape(input_shape)
        input_spec = tf.keras.layers.InputSpec(min_ndim=3, axes={-1: f_shape[-1]})

        if f_shape[-1] is None:
            raise ValueError(
                "The last dimension of the inputs to `CRF` "
                "should be defined. Found `None`."
            )
        if f_shape[-1] != self.output_dim:
            raise ValueError(
                "The last dimension of the input shape must be equal to output"
                " shape. Use a linear layer if needed."
            )
        self.input_spec = input_spec
        self.transitions = self.add_weight(
            name="transitions",
            shape=[self.output_dim, self.output_dim],
            initializer="glorot_uniform",
            trainable=True,
        )
        self.built = True

    # pylint: disable=arguments-differ
    def call(self, inputs, sequence_lengths=None, **kwargs):
        """Decode `inputs` with the Viterbi algorithm.

        Args:
            inputs: tensor of shape `(batch_size, sentence length, num_classes)`.
            sequence_lengths: optional int32 tensor of shape `(batch_size, 1)`.
                When omitted, every sequence is treated as full length.

        Returns:
            The raw `inputs` in the training phase (so `loss` can score them);
            otherwise the one-hot encoded Viterbi path.
        """
        sequences = tf.convert_to_tensor(inputs, dtype=self.dtype)
        if sequence_lengths is not None:
            assert len(sequence_lengths.shape) == 2
            assert tf.convert_to_tensor(sequence_lengths).dtype == "int32"
            seq_len_shape = tf.convert_to_tensor(sequence_lengths).get_shape().as_list()
            assert seq_len_shape[1] == 1
            self.sequence_lengths = tf.keras.backend.flatten(sequence_lengths)
        else:
            self.sequence_lengths = tf.ones(tf.shape(inputs)[0], dtype=tf.int32) * (
                tf.shape(inputs)[1]
            )

        viterbi_sequence, _ = tf.contrib.crf.crf_decode(
            sequences, self.transitions, self.sequence_lengths
        )
        output = tf.keras.backend.one_hot(viterbi_sequence, self.output_dim)
        return tf.keras.backend.in_train_phase(sequences, output)

    def loss(self, y_true, y_pred):
        """Mean negative CRF log-likelihood of the one-hot `y_true` tags.

        Uses the sequence lengths recorded by the most recent `call`.
        """
        y_pred = tf.convert_to_tensor(y_pred, dtype=self.dtype)
        # crf_log_likelihood also returns the transition params it was given.
        # Discard that value instead of rebinding the layer's tracked weight attribute.
        log_likelihood, _ = tf.contrib.crf.crf_log_likelihood(
            y_pred,
            tf.cast(tf.keras.backend.argmax(y_true), dtype=tf.int32),
            self.sequence_lengths,
            transition_params=self.transitions,
        )
        return tf.reduce_mean(-log_likelihood)

    def compute_output_shape(self, input_shape):
        """Return `(batch_size, sentence length, num_classes)` as a `TensorShape`.

        Normalising through `tf.TensorShape` makes list, tuple and `TensorShape`
        inputs all work; the previous `input_shape[:2] + (...)` failed for lists.
        """
        input_shape = tf.TensorShape(input_shape)
        input_shape.assert_has_rank(3)
        return input_shape[:2].concatenate([self.output_dim])

    @property
    def viterbi_accuracy(self):
        """Return a Keras metric: categorical accuracy of the Viterbi-decoded path."""

        def accuracy(y_true, y_pred):
            shape = tf.shape(y_pred)
            # NOTE(review): decodes every sequence at full length, ignoring the
            # lengths recorded in `self.sequence_lengths` — confirm this is intended.
            sequence_lengths = tf.ones(shape[0], dtype=tf.int32) * (shape[1])
            viterbi_sequence, _ = tf.contrib.crf.crf_decode(
                y_pred, self.transitions, sequence_lengths
            )
            output = tf.keras.backend.one_hot(viterbi_sequence, self.output_dim)
            return tf.keras.metrics.categorical_accuracy(y_true, output)

        # Python 3 reads `__name__` (Keras uses it to name the metric).
        # `func_name` is the Python 2 spelling and is kept for compatibility.
        accuracy.__name__ = "viterbi_accuracy"
        accuracy.func_name = "viterbi_accuracy"
        return accuracy